{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qKlB5QTDIV6S"
      },
      "source": [
        "# LoRA (Sampling)\n",
        "\n",
        "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/google-deepmind/gemma/blob/main/colabs/lora_sampling.ipynb)\n",
        "\n",
        "Example of using LoRA with Gemma (for inference). For fine-tuning with LoRA, see the [LoRA finetuning](https://github.com/google-deepmind/gemma/blob/main/docs/lora_finetuning.md) example."
      ]
    },
    {
      "metadata": {
        "id": "TR-L25KVKT_F"
      },
      "cell_type": "code",
      "source": [
        "!pip install -q gemma"
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "I6fEKB1tISVW"
      },
      "outputs": [],
      "source": [
        "# Common imports\n",
        "import os\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import treescope\n",
        "\n",
        "# Gemma imports\n",
        "from gemma import gm\n",
        "from gemma import peft  # Parameter-efficient fine-tuning (PEFT) module"
      ]
    },
    {
      "metadata": {
        "id": "cxGT2XeU4L47"
      },
      "cell_type": "markdown",
      "source": [
        "By default, JAX does not use the full GPU memory, but this can be overridden. See [GPU memory allocation](https://docs.jax.dev/en/latest/gpu_memory_allocation.html):"
      ]
    },
    {
      "metadata": {
        "id": "o4MidM--4L47"
      },
      "cell_type": "code",
      "source": [
        "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"]=\"1.00\""
      ],
      "outputs": [],
      "execution_count": null
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-kdAZkvOIryQ"
      },
      "source": [
        "## Initializing the model\n",
        "\n",
        "To use Gemma with LoRA, simply wrap any Gemma model in `gm.nn.LoRA`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "x-BbrzCVIupV"
      },
      "outputs": [],
      "source": [
        "model = gm.nn.LoRA(\n",
        "    rank=4,\n",
        "    model=gm.nn.Gemma3_4B(text_only=True),\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hI3Lg07SJff4"
      },
      "source": [
        "Initialize the weights:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "1shC1DpiJfsw"
      },
      "outputs": [],
      "source": [
        "token_ids = jnp.zeros((1, 256), dtype=jnp.int32)  # Dummy input of shape (batch_size, seq_length)\n",
        "\n",
        "params = model.init(\n",
        "    jax.random.key(0),\n",
        "    token_ids,\n",
        ")\n",
        "\n",
        "params = params['params']"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "T3dWILqKKzG3"
      },
      "source": [
        "Inspect the params shape/structure. We can see LoRA weights have been added."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "LMq2Z9nXKcad"
      },
      "outputs": [],
      "source": [
        "treescope.show(params)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "bGJl5YpKKOf-"
      },
      "source": [
        "Restore the pre-trained params. We use `peft.split_params` and `peft.merge_params` to replace the randomly initialized non-LoRA params with the pre-trained ones, while keeping the LoRA weights.\n",
        "\n",
        "When using `gm.ckpts.load_params`, make sure to pass the `params=original` kwarg. This ensures that:\n",
        "\n",
        "* The memory from the old params is released (so only a single copy of the weights stays in memory)\n",
        "* The restored params reuse the same sharding as the input (there is no sharding here, so this part is not required)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "AcO6oBuLKNjb"
      },
      "outputs": [],
      "source": [
        "# Splits the params into non-LoRA and LoRA weights\n",
        "original, lora = peft.split_params(params)\n",
        "\n",
        "# Load the params from the checkpoint\n",
        "original = gm.ckpts.load_params(gm.ckpts.CheckpointPath.GEMMA3_4B_IT, params=original)\n",
        "\n",
        "# Merge the pretrained params back with LoRA\n",
        "params = peft.merge_params(original, lora)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6jYRSpv2IwIY"
      },
      "source": [
        "## Fine-tuning\n",
        "\n",
        "See our [finetuning guide](https://github.com/google-deepmind/gemma/blob/main/docs/lora_finetuning.md) for more info.\n",
        "\n",
        "For an end-to-end finetuning example, see our [lora.py](https://github.com/google-deepmind/gemma/tree/main/examples/lora.py) config."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MvsQbQM4I4Cs"
      },
      "source": [
        "## Inference\n",
        "\n",
        "Here's an example of running a single model call:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eqU7a4eCI5Wr"
      },
      "outputs": [],
      "source": [
        "tokenizer = gm.text.Gemma3Tokenizer()\n",
        "\n",
        "prompt = tokenizer.encode('The capital of France is')\n",
        "prompt = jnp.asarray([tokenizer.special_tokens.BOS] + prompt)\n",
        "\n",
        "\n",
        "# Run the model\n",
        "out = model.apply(\n",
        "    {'params': params},\n",
        "    tokens=prompt,\n",
        "    return_last_only=True,  # Only predict the last token\n",
        ")\n",
        "\n",
        "\n",
        "# Show the token distribution\n",
        "tokenizer.plot_logits(out.logits)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6dOSL9MHuMUa"
      },
      "source": [
        "To sample an entire sentence:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_ckwREdyqown"
      },
      "outputs": [],
      "source": [
        "sampler = gm.text.ChatSampler(\n",
        "    model=model,\n",
        "    params=params,\n",
        "    tokenizer=tokenizer,\n",
        ")\n",
        "\n",
        "sampler.chat('The capital of France is?')"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "last_runtime": {},
      "private_outputs": true,
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
